Vbert #339
Conversation
e77142f to 5c11cd3
* modeling
* update modeling
* update token id default
* init files
* remove vllama + update torch lower bound for cpu
* back to normal transformer bound
* clean
* Update colpali_engine/models/__init__.py

Co-authored-by: QuentinJGMace <[email protected]>
Mostly comments about the form, overall LGTM!
-ColQwen2_5Omni,
-ColQwen2_5OmniProcessor,
+# ColQwen2_5Omni,
+# ColQwen2_5OmniProcessor,
Add a comment to the README if ColQwen 2.5 Omni is no longer supported.
 # Process queries.
-queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries]
+# queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries]
Remove the commented-out lines if they are not useful.
Actually useful: in ModernVBERT, self.processor.query_prefix is "", but the commented line is useful if somebody wants to reproduce older models.
Thanks for flagging it!
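For context, a standalone sketch of what re-enabling that prefix would look like; the mock processor and its attribute values below are illustrative assumptions, not code from this PR:

```python
# Illustrative mock: query_prefix is "" for ModernVBERT but non-empty for some older models.
class MockProcessor:
    query_prefix = "Query: "            # assumed value for an older model; "" for ModernVBERT
    query_augmentation_token = "<pad>"  # placeholder augmentation token

processor = MockProcessor()
queries = ["what does figure 3 show?"]
# The commented-out line in the diff corresponds to this prefixed form:
queries = [processor.query_prefix + q + processor.query_augmentation_token * 10 for q in queries]
print(queries[0])
```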
 # Process queries.
-queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries]
+# queries = [self.processor.query_prefix + q + self.processor.query_augmentation_token * 10 for q in queries]
+queries = [q + self.processor.query_augmentation_token * 10 for q in queries] if is_str else queries
put 10 into a constant (e.g. N_AUGMENTATION_TOKENS)
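A minimal sketch of that suggestion; the helper function below is hypothetical and only illustrates naming the magic number:

```python
# Sketch: replace the bare 10 with a named constant, as suggested above.
N_AUGMENTATION_TOKENS = 10  # number of query augmentation tokens appended to each query

def augment_queries(queries, augmentation_token="<pad>"):
    """Append N_AUGMENTATION_TOKENS copies of the augmentation token to each query."""
    return [q + augmentation_token * N_AUGMENTATION_TOKENS for q in queries]

print(augment_queries(["what does figure 3 show?"]))
```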
else:
    proc_batch[k] = v
unnecessary
-query_outputs = model(input_ids=inputs["query_input_ids"], attention_mask=inputs["query_attention_mask"])
+query_outputs = model(**{k[6:]: v for k, v in inputs.items() if k.startswith("query")})
+# feed only kwargs with 'doc_' prefix
+doc_outputs = model(**{k[4:]: v for k, v in inputs.items() if k.startswith("doc")})
Define a var/constant for len("doc_").
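A possible shape for that refactor; the constant and helper names are my own suggestions, not code from the PR:

```python
# Sketch: strip the "query_"/"doc_" key prefixes with named constants instead of magic slice indices.
QUERY_PREFIX = "query_"
DOC_PREFIX = "doc_"

def strip_prefixed_kwargs(inputs, prefix):
    """Return the entries of `inputs` whose keys start with `prefix`, with the prefix stripped."""
    return {k[len(prefix):]: v for k, v in inputs.items() if k.startswith(prefix)}

inputs = {"query_input_ids": [[1, 2]], "query_attention_mask": [[1, 1]], "doc_input_ids": [[3, 4]]}
print(strip_prefixed_kwargs(inputs, QUERY_PREFIX))  # {'input_ids': [[1, 2]], 'attention_mask': [[1, 1]]}
print(strip_prefixed_kwargs(inputs, DOC_PREFIX))    # {'input_ids': [[3, 4]]}
```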
| """ | ||
| Helper function to reshape negative doc inputs to (batch_size * num_neg_docs, ...) | ||
| """ | ||
| neg_doc_inputs = {k[8:]: v for k, v in inputs.items() if k.startswith("neg_doc")} |
Define a var/constant for 8 (i.e. len("neg_doc_")).
Could rename variables for more clarity, use constants, and add a docstring.
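Putting these suggestions together for the negative-doc helper; the names and the reshape step below are assumptions based on the snippet and docstring above, not the PR's final code:

```python
import torch

NEG_DOC_PREFIX = "neg_doc_"  # replaces the magic slice index 8 (== len("neg_doc_"))

def reshape_neg_doc_inputs(inputs):
    """Extract the "neg_doc_*" kwargs and flatten them to (batch_size * num_neg_docs, ...)."""
    neg_doc_inputs = {
        k[len(NEG_DOC_PREFIX):]: v for k, v in inputs.items() if k.startswith(NEG_DOC_PREFIX)
    }
    return {k: v.reshape(-1, *v.shape[2:]) for k, v in neg_doc_inputs.items()}

batch = {"neg_doc_input_ids": torch.zeros(2, 3, 5, dtype=torch.long)}
print(reshape_neg_doc_inputs(batch)["input_ids"].shape)  # torch.Size([6, 5])
```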
save as test_bi_losses
assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}"

# # Check if the maximum scores per row are in the diagonal of the matrix score
# assert (scores.argmax(dim=1) == torch.arange(len(ds), device=scores.device)).all()
why is this commented out?
Still to do:
Modify all the negatives losses, not just the ones used for vbert.